from Model.InferenceModule.inference_module import InferenceModule
import numpy as np
from tianshou.data import Batch
from Model.InferenceModule.module_utils import apply_mask
from Network.network_utils import pytorch_model
from Model.InferenceModule.module_utils import trace_log_probs

class AllModule(InferenceModule):
    """Inference module named "all": predicts mask logits with ``inter_model``
    and then runs ``all_model`` on the same input under one or more masking forms.

    The input to both networks is the concatenation of ``batch.target`` and
    ``batch.obs`` along the last axis. The mask passed to ``all_model`` is built
    from the logits by ``apply_mask``.
    """

    # Masking forms ``__call__`` can evaluate. The configured ``masking_form``
    # is always evaluated; the others only when requested via ``additional``.
    _ALL_MASKING_FORMS = ('flat', 'soft', 'hard', 'mixed')

    def __init__(self, args, extractor, dists, inter_model, all_model):
        """
        Args:
            args: configuration object. ``args.inter`` and ``args.inter.masking``
                are read here; the whole object is forwarded to
                ``InferenceModule.__init__`` and ``init_optimizer``.
            extractor: forwarded to ``InferenceModule.__init__``.
            dists: distribution container. ``dists.forward`` is cached as
                ``self.forward_dist`` and the container is passed to ``apply_mask``.
            inter_model: network that produces the mask logits.
            all_model: network that produces distribution parameters from the
                state and a mask.
        """
        super().__init__(args, extractor)
        self.mp = args.inter
        self.mask_args = args.inter.masking
        # all_model is asked for the distribution of the configured masking form
        self.dist_settings = [self.mask_args.masking_form]
        self.name = "all"
        self.inter_model = inter_model
        self.model = all_model

        self.dists = dists
        self.forward_dist = self.dists.forward
        # Called last, presumably because it needs the networks assigned above
        # (TODO confirm against InferenceModule.init_optimizer).
        self.init_optimizer(args)

    def __call__(self, batch, valid, extractor, normalizer, additional=None, grad_settings=None, log_batch=None, full=False, probs=False, keep_all=False):
        """Compute mask logits and, unless ``probs``, masked predictions for ``batch``.

        Args:
            batch: Batch with at least ``target`` and ``obs`` fields. It is
                filtered once by ``self.get_omit`` and only the filtered batch is
                used afterwards.
            valid: validity array aligned with ``batch``; filtered the same way.
            extractor: only ``extractor.num_objects`` is read.
            normalizer: not used by this module; kept for call compatibility.
            additional: names of extra outputs requested from both networks
                (passed as ``ret_settings``). A masking form listed here
                ('flat', 'soft', 'hard', 'mixed') is also evaluated. Defaults to
                no extras.
            grad_settings: forwarded unchanged to both networks. Defaults to empty.
            log_batch: keys of the filtered ``batch`` to copy into the result.
            full: if True, skip ``inter_model`` and use all-ones ``mask`` and
                ``mask_logits`` of length ``extractor.num_objects``.
            probs: if True, return as soon as the mask logits are available.
            keep_all: forwarded to ``self.get_omit``.

        Returns:
            A ``Batch`` with ``mask_logits``, ``omit_flags`` and the logged keys.
            ``mask`` is included when ``full``; ``mask_add`` is included otherwise.
            Unless ``probs`` is set, it also holds one sub-Batch per evaluated
            masking form (``params``, ``mask``, ``full_active_input``, ``target``,
            ``log_probs``, ``trace_log_probs`` and the ``additional`` entries).
            The configured ``masking_form``'s entries are also copied to the top level.
        """
        additional = [] if additional is None else additional
        grad_settings = [] if grad_settings is None else grad_settings
        log_batch = [] if log_batch is None else log_batch

        # TODO: the logic is very similar to full_module. omit_flags does not omit
        # invalid entries because validity differs per object (keep_invalid=True);
        # there is no invalid for "all".
        omit_flags = self.get_omit(batch, keep_all=keep_all, keep_invalid=True)
        batch = batch[omit_flags]
        valid = valid[omit_flags]
        key_query_state = np.concatenate([batch.target, batch.obs], axis=-1)

        result = Batch()
        if full:
            # bypass the interaction model: an all-ones mask over every object
            result.mask = pytorch_model.wrap(np.ones(extractor.num_objects), cuda=self.iscuda)
            result.mask_logits = pytorch_model.wrap(np.ones(extractor.num_objects), cuda=self.iscuda)
        else:
            result.mask_logits, info = self.inter_model(key_query_state, valid=valid, ret_settings=additional, grad_settings=grad_settings)
            result.mask_add = Batch()
            result.mask_add.mask_input, _, _, info = info
            for i, aname in enumerate(additional):
                result.mask_add[aname] = info[i]

        # ``batch`` was already filtered by ``omit_flags`` above. Indexing it with
        # ``omit_flags`` a second time failed or picked the wrong rows whenever
        # anything was omitted.
        for k in log_batch:
            result[k] = batch[k]
        result.omit_flags = omit_flags

        if probs:  # only the mask probabilities are needed; skip all_model
            return result

        # Run all_model once per requested masking form, reusing the same logits.
        # Only the configured masking form is promoted to the top level of result.
        masking_form = self.mask_args.masking_form
        for masking in self._ALL_MASKING_FORMS:
            if masking not in additional and masking != masking_form:
                continue
            out = Batch()
            mask = apply_mask(self.mask_args, self.dists, result.mask_logits,
                              soft=masking in ('soft', 'mixed'), flat=masking == 'flat',
                              mixed=masking == 'mixed', test=self.test, iscuda=self.iscuda)
            out.params, out.mask, info = self.model(key_query_state, m=mask, valid=valid, dist_settings=self.dist_settings, ret_settings=additional, grad_settings=grad_settings)
            out.full_active_input, _, _, info1, info2 = info
            info = list(zip(info1, info2))
            out.target, out.log_probs = self._target_dists(batch, out.params)
            out.trace_log_probs = trace_log_probs(extractor.num_objects, out.log_probs, batch)
            # Attach the extra outputs only to forms computed in this iteration,
            # so ``info`` is never stale or undefined here.
            for i, aname in enumerate(additional):
                out[aname] = info[i]
            result[masking] = out
            if masking == masking_form:
                for name in out.keys():
                    result[name] = out[name]
        return result